{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Action1_pytorch_cifar10_accuracy%96"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "id": "2TQg-EeiGSNs",
    "outputId": "272cb183-702d-4ae4-926b-d377f3be89e0"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Mounted at /content/drive\n"
     ]
    }
   ],
   "source": [
    "# Mount Google Drive in Colab and make the project folder the working directory\n",
    "from google.colab import drive\n",
    "drive.mount('/content/drive')\n",
    "import os\n",
    "# Relative paths used later in the notebook (e.g. root='./') resolve against this folder\n",
    "os.chdir(\"/content/drive/My Drive/Colab Notebooks/Bi-IV/CNN\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "id": "9ArUMpK1FYbB"
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import torchvision\n",
    "import torchvision.datasets as datasets\n",
    "import torchvision.transforms as transforms\n",
    "import matplotlib.pyplot as plt\n",
    "import time\n",
    "from torch.utils.data import DataLoader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 35
    },
    "id": "2ajTInVMGRqY",
    "outputId": "14117bd6-6e49-40ad-9dc6-9dde161c7db3"
   },
   "outputs": [
    {
     "data": {
      "application/vnd.google.colaboratory.intrinsic+json": {
       "type": "string"
      },
      "text/plain": [
       "'/content/drive/My Drive/Colab Notebooks/Bi-IV/CNN'"
      ]
     },
     "execution_count": 6,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show the current working directory (explicit %pwd magic instead of relying on automagic)\n",
    "%pwd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "id": "_iTF22QUFYbH"
   },
   "outputs": [],
   "source": [
    "# Load the CIFAR-10 train and test splits from the working directory, converting images to tensors.\n",
    "# download is not requested here, so the dataset files are expected to already be under root\n",
    "# (pass download=True to fetch them).\n",
    "train_data = datasets.CIFAR10(\n",
    "    root='./',\n",
    "    train=True,\n",
    "    transform=transforms.ToTensor(),\n",
    ")\n",
    "test_data = datasets.CIFAR10(\n",
    "    root='./',\n",
    "    train=False,\n",
    "    transform=transforms.ToTensor(),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 90,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((50000, 32, 32, 3), (10000, 32, 32, 3))"
      ]
     },
     "execution_count": 90,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 50,000 training images and 10,000 test images, each stored as a 32x32x3 array\n",
    "(train_data.data.shape, test_data.data.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 95,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.image.AxesImage at 0x7f2679c21410>"
      ]
     },
     "execution_count": 95,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPgAAAD2CAYAAAD720p7AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAeA0lEQVR4nO2dW2xc15Wm/1V33q+iREnU1ZIiW7LjhNbYjiZO3GO7u8dJd9rTAw8QdNB58EsHefAgT4MA/ZKHvDQQID3BGJgBMp6g52a423GQTvs26TiOI8jtSJZkWfcLJZGSKJFFsu5Vex5Ez8jS/rcomsXi7Pk/gBC1F3edVbvOqlO1/7PWMucchBBxkmi1A0KI5qEAFyJiFOBCRIwCXIiIUYALETGpZh+gq7vHDQyt9toqpQKdV6uUvOPOGZ2TzuSoLZPltmQ6Q22JhP94peIsnVMpF6nN1evUZuDPLZFM8nkJ//t0R2cXnZMNrIer16itWOSvGeBXZBquQWeUinyt6gE/QuoPM9Vq3I9GI/R4fF4qxUMoleKvmYP/PAiJWg3uBqan8ledc6tu84FP8WNmOQD/E8AIgIMA/swFVntgaDX+3V/9e69t7Oh79DhXTn/oHa/XucurN3yG2jZs3UltfWs2UFuuzX+8Y4ffoXPOnjhIbdUZ/saQDDy37r4eakvl2r3je77wRTrnnu18rUrT16jt8KH3qa3RqHjHK1X/mzUAHDn8AbXlp65SW7lSprZqxR9Y1yb5m9NsgftYq/NjrVrVT219/Z3UVncz/mNV6RSUijz6/+7lfzjrG1/MR/SvAxhzzj0AoA/AE4t4DCHEMrCYAH8cwGvzv78J4MtL544QYilZTIAPAJie/z0P4LbPKGb2nJntN7P9M/npW81CiGViMQF+FcDHXwh75v//CZxzLzjnRp1zo13d/LujEKK5LCbA3wDw5PzvjwN4a+ncEUIsJYuRyX4C4E/M7CCAA7gR8JR6vY78df+u7EAv34F0q/zSmkt10znDG7ZwPxp8ezLR4LurjYJfqildn6RzXJHvyK4bHKK2DSP3UNvIPRupbe269d7xISJPAkA6naW2Wq9/Vx4ARtav4fNq/l30UolLYVPXuapw9SrfzU8FJFGYfxe9b4A/51wH93E6f53asjkeQg3HZb50yu9LfnqKzqmU7z4x7K4D3DlXBvD0XR9JCLHs6E42ISJGAS5ExCjAhYgYBbgQEaMAFyJimp5NBueAql+iqpS5dFUo+CWXTdvX0Tmzc3PUFkp46B8MJHKk/e+B27Ztp3MefXiU2tat9ktaANDTc1sy0P+hmuJZaO05v+SSCqgqVgtkjM1x6apMXksAaG/zy2t9vVwa3LrlXmr78MOPqA3G/SiX/bJnT3cfnRNIKMR0foLaHPznKRDOULt+3X+uFgs8sWUx5RN1BRciYhTgQkSMAlyIiFGACxExCnAhIqbpu+iu0UCNJBtYje8MZzNt3vHpq7yMz8AavkO94T6eyDE0spba0mx7NVBbp1rjO/ZHL/EklcKpK/wxE3y39qMPDnjHH9rJd6i/uOchagvVO8sH8vvPnb3oHc+kA7XyMjx5aHAVV0zOnT/OH5OUsJotcpUln+fnVSrNa+V1d/PEnFD9OlZuLlQ3LpsNbPUTdAUXImIU4EJEjAJciIhRgAsRMQpwISJGAS5ExCyLTFYu+OWJzjYun3T3+xMvPvfAZ+mckS3bqG0mkFzx0anz1JYv+KWO2SleO2tyikthl8Z5fa/uQLIJEjwJ4dX/9pJ3PP2v+fv3Y4/spbZ0mkuAa9ZwSRHOLzVNXfd38QCAf3qfd4FJBerGdXRxea1W98t8lVn+miUDl7pQ95J6ncuXk9e49JaAX14LtULq7b37CsW6ggsRMQpwISJGAS5ExCjAhYgYBbgQEaMAFyJimi6TWcKQzaa9tmqyi84rtvmbp5/O8xYzv3t7H7Vd
m+R1xi5c5DW30kl/JlE6wbN+yqSFDwCUStw2vIq/HJfHvf3dAQDdJMtoZipP5xw7fZr7MTxIbek093F4xN/WaC0ZB4Bz41yi/OgDbhsa5pLimXNEnqry16xR4bZ6oB5eLsOlvGzKf94DQLHkf8zubi7/pUi7oxCLuoKb2UNmNmZmb8//7FjM4wghmstir+B9AH7knPveUjojhFhaFvsdvA/AM2a2z8xeMjOeES+EaBmLDfATAL7rnNsDYBjAYzcbzew5M9tvZvvnZvl3XyFEc1lsgJ8B8PpNv3+isr1z7gXn3KhzbrSj079ZJoRoPosN8OcBPGtmCQC7ABxaOpeEEEvFYjfZfgjgbwB8C8DLzrkj7A8TiRTa21d7bZeneIbXifN+ieTIYf5ekghIOPVAm6TiDC/GlyRyWLHMJaipGW6bCbQFOjP2IbV1tHFJccdWImIE5Lpf/+p/UdvGzZupbfsO3rJpYMCf7ZTN8delp5tLP4kaL/A4V+bXJtb+pzjFs9rqdV4oM9fG5a7ZPH/M7kDGWzaX9I5XKqF2XryII2NRAe6cuwTgS4uZK4RYPnQnmxARowAXImIU4EJEjAJciIhRgAsRMU3PJksmU+jt92cnnTh/jM67dMaf7dSe5sUHp+d4QcPZ/GVqswbPJJqa8ctaU0Uuq6RI9hwADK4eora2Ll5Ub92mB6hthEgupw/8hs5JGpfQqnWePXXlKi8ouXv3Tu/4Pdu20DkjgaywzocfpLaDR89RW7nkL+ZZTgeyycAlrYbjcu74uL8fGwBkslwC7Olj5wGXbItFnknJ0BVciIhRgAsRMQpwISJGAS5ExCjAhYiYpu+il8tzOHnSXyvt6MkTdN7FSye94/VAYkhXTwe17di2idp27dxFbZeu+Hcuz17hfqxa40+uAYCNW3kiR9cA32GfuM6P5676FYdzZ/lO85VAe6Wd91ITntju3ykHgLlZ/1o1+KY8XIXv5h9+l6sA23bwFlar1/V6x9/d9490zvgETxCqVvkueqnI/b8eaNnU1un3seH4Tv8caQEWQldwISJGAS5ExCjAhYgYBbgQEaMAFyJiFOBCREzTZbK52Tze/cfX/AdfzRuibN252zveFmgxs/PebdS2Y/t6aquX/MkaAOASfulnDqQ9DoBU2p/sAADJpF8eAYBqjScnzM1co7aeil/GqdUdnXPuMk/MyXVe4Mfq7qO2LVs3ecdd4DpSnOJ1xo7+9nfU5or8PNj11O97x3ffz5Neivu5THbyxBlqa2/nVYN7egeoDfBrh/k8f13K5buvyaYruBARowAXImIU4EJEjAJciIhRgAsRMQpwISKm6TJZtVLD5fN+SenBB/4lnZfN+mt19XNFC8NreV2ta4G2NedPcAmq0vBLVwnjKVLJFJdw6o7XlEMt1HqJ1+Nydf/xOnv8tfAAYHKWZyYlMjwrr+G49AYQG18OdOb4a7Zp7Qi15ZLcjwT8dfR27+KZfL29XL58pfgP1DZ+icta64bWUlvd/DX90oH2W/k8l/IAf9urBV3BzSxtZj+d/z1nZq+a2QEze1G9wYVYudwxwM2sDcB7AJ6YH/o6gDHn3AMA+m4aF0KsMO4Y4M65onPufgBj80OPA/j41rQ3AXy5Sb4JIT4li9lkGwDwcV/XPID+W//AzJ4zs/1mtr9W49UwhBDNZTEBfhXAxxX6e+b//wmccy8450adc6OpVNP38YQQhMUE+BsAnpz//XEAby2dO0KIpWQxl9efAPgTMzsI4ABuBDwlkUihvfO2T/EAgHRAcZma8rcayvZzOaNQ43pMiXcaQltfF7VlG0QkKHGZzAVWtVTlGUG5Nj4xEWg11Ej453UOcJkm47g0mGzjGWMuw3XKhvmfm9W57JZI8uec7shQW1snt9XKfkl08sIEnTPQwVso/dEfPkVt+w+cobbZQEHGUvmKd7wcaE/U28XPfcaCA9w5d8/8v2UAT9/1kYQQy47uZBMiYhTgQkSMAlyIiFGA
CxExCnAhIqbpd6FkMlkMb/Bn8ViCv7+USv7MmYk8dznTy7OnqjUuq1g6TW3FWX9mUtVx31MpXjyxluS29m6eWTU0MEVt7ppfWqkEempZg/vf1tZGbYlANl/D+Y9Xr3NJMZEOFLxMch9n53h2oDX8cmk2cL7lr3AJra3dL/MCwBcfuZ/aPjp5ltoOHRn3js/meZZfJlDMk6EruBARowAXImIU4EJEjAJciIhRgAsRMQpwISKm6TKZM8CZXwqpBmScwoxfBskGJJyZfKB4YokXOyzkueSSJslkXR1c7lrVx2WV7n6eWbWqlz+3eqqH2opZ/zpe28izycr1S9SGQMZbvRbIaiOZd/UEz/KzgEzW28+z2hr1gI/kvOrp4eubMZ7aODUTkCirfhkVAD67cw219Xb5z59XX+UFHq9M8H54DF3BhYgYBbgQEaMAFyJiFOBCRIwCXIiIaX7JU+cAsvOaavAd2R5yX/1ID2+k8pktvGZVZ47voCaNv8/N5f07qKXCtHccANo6qtS2YxvfYR/ZuJ7aEumN1DY75fdxZHiY+3HaX/MOALr7eVJDfx9PiEml/Ak9jUDtPRdIXsl1tFNbrcQVmAQ5XjqU3ASusgwMdlLbbIHv5s9N+RNKAGDdKn8NuD/+ypPecQD425+9Tm0MXcGFiBgFuBARowAXImIU4EJEjAJciIhRgAsRMU2Xybo62vHYI5/32rbc+wCdd/HCBe/4urVcZtq+bSu1rVk1RG1Jx6W3GZJoUA4kZFiCP15nB0826ezk8lQyw2W+NJEbi3P+9jgA8LldXHbbtH0TtVUbXAJ05HpRa3BJyyX5WiXT/PSslrj21iDJJokUv55ZjvuBwLxyla9HKslr/dUr/vNqVUCS2/vPH6K2//7Sa97xBV3BzSxtZj+d//0hMxszs7fnf3Ys5DGEEMvPHa/gZtYG4LcAts8P9QH4kXPue810TAjx6bnjFdw5V3TO3Q9gbH6oD8AzZrbPzF4ys8BnGyFEK1nMJtsJAN91zu0BMAzgsVv/wMyeM7P9ZrZ/do4nxAshmstiAvwMgNdv+v223Svn3AvOuVHn3GhnB980EEI0l8UE+PMAnjWzBIBdAA4trUtCiKViMTLZDwH8DYBvAXjZOXck9Mft7W34/P2f8drue5DLZMVdfsmro4dnM/HKX4ALbBUkAnJGf4e/rlagc1HwXbNB2uoAQC1Qow4BOaZc9rcu2nrPBjqnLcPluuIcz5RzicApY36bC9Q7azhuqwdes0YgRa1S9K9HvcGfcyIVOD8Cr+jMJJdLz54+T21f2Pugd7xQ5fUB20NSHmHBAe6cu2f+30sAvnTXRxJCLDu6k02IiFGACxExCnAhIkYBLkTEKMCFiJimZ5MlEgm0kQyqzhxv/9PRTlxL8Sp9oeJ+oTtqEyE5xvllrUaVy10h6ccChf9qAaEvkKAGR4pGdvbyzLtanR+r3ghUQiTtiQDAoe4dT4Scr3NbPcXlS4fAi02KfFrD7x8AZAPPOV3nr1lHic9zE365DgCunJrwjq/fwQtvXk3c/V2huoILETEKcCEiRgEuRMQowIWIGAW4EBGjABciYpoukyWTSXT1+OUaF8jiKpT9Uocr8x5SZTIHAOZm56itUuXzymV/FletxmWmaiDzqxo4ViHQ56owx7OMaiRDrau/h87p6uF93Hq7Bqktl/H3HwOAOus1Z4E+YuC2ri5ehHLyMl/HUtEvJzUafXSOgT+vRp2fc91dXOrduGE1tRUL/vPRBQpU9nTxbDiGruBCRIwCXIiIUYALETEKcCEiRgEuRMQ0fRd9aiqPv33l515bPf0rOu/6df/N+LPTV+mcRCD/ILTDPjHhPxYA1EkGS3+gFVLf4AC1ZZN8yeeu+dvZAMCx4x9SW37Wv2s8spm3J0qmuYLR3cX937yZ13lbP+KvX7d5yzo6pz/Lk026ctzHRqA2H5L+BJBqne9QJwPtiZIBH1dvCigO3XyHver8iS9JvpmP/v7A
cyboCi5ExCjAhYgYBbgQEaMAFyJiFOBCRIwCXIiIabpMlp+ZxWtvveO19a7fQee5ul/6ef+dt+icjet5PavBAS79XBgbp7YaqePV3s+TNSoJnogyMcbb2fzenkeo7bP330dthXLJO55I85f39Lmz1Hbs+Elq++DQ+9TW2+NvNPnMv/oanfOF+7ZTWybQH2r98Ai1VYhMZoHacKE6elVSaw4AEqlAnbdenizTRmrzNZJczuWiIWdBV3Az+7GZvWtmr5hZp5m9amYHzOxF9QcXYuVyxwA3s70AUs65hwF0A/gmgDHn3AMA+gA80VwXhRCLZSFX8AkAP7jp7/8SwGvz/38TwJeX3i0hxFJwx+/gzrnjAGBmX8ONDr3vA/i4v2wewG1fpM3sOQDPAUAu175Uvgoh7pKFfgf/KoBvA/gKgHEAH5cK6QFw283hzrkXnHOjzrnRTIbfjyuEaC4L+Q6+BsB3ADztnJsB8AaAJ+fNjwPg29pCiJayEJnsGwCGAfxifsP8RQDrzOwggAO4EfCUvv4B/Om/+TOvLTu0jc4rzPilq+MfHKBzhtdw6SQRaBnUluNZOpWGv/3M9l3c975hnmlWGOR1wZ7+g39Bbe1dbdQ2R2SyQJch1EhLJgAo1fyPBwCXL1+jtrOnL3rH29v5+o6PTVLbmcPHqS1R4j6eGr/sHd/z5Cids3HTWmoLZaElcoH0rzSX0IzVXjM+J2P8NWMs5Dv49wF8/5bh/3DXRxJCLDu6k02IiFGACxExCnAhIkYBLkTEKMCFiJimZ5OZAdmM/33k2NFDdF5+2i+TuVDWT4Vn4swGWheF8mVyWX8OT7XAWwlNX+E+Tpzj2WQ//4W/OCUAXJ8JHG922jve1c3lqZ4+fzspAOgIFAscG/NLYQAwNOgvrpjr5rLhr37Gn/O14weprV7h7aFOjPuLaI4F2j9t28llz55ufjdmTx9vD9XWzrPJejr851U658+EA4D29ru/aUxXcCEiRgEuRMQowIWIGAW4EBGjABciYhTgQkRM02WyRq2KmUm/5PXm3/2Mzjs/PuYdT1T92V0AcPBgnjsSkMJqNZ4tBJLB89qrb9IpmTSXMz774OeorZLporZ8uUBtp875s6cmJ3k/s0qJZyZdHD9DbafP8MccffDz3vFv/8XzdM6+d39DbbVpnmmWL5eprQi/THlqP5cof/XeJWrrSHFJLp3hslYyy8+DLiKTrd+4ic75o2eepTaGruBCRIwCXIiIUYALETEKcCEiRgEuRMQ0fRc9nc5gePWw17Zt02Y6z8G/y5sKtAVKBnbKE0n+XuYaPDkkk+vwG9I8kWDtWn/SBQB86amnqK2rPZDUkOO13I4c8tepO3aCtyBas24TtZUCLYOSbdzHQ8eOesePHDtG57Rv2kltFy/y59zXy21DGX+dtPZOXtfu2jhv5TR54QS1XbnqT2wBgFI9kBhFCuZdmuIh+ejv3X0TIV3BhYgYBbgQEaMAFyJiFOBCRIwCXIiIUYALETFNl8lqtRquXfG3u3n4nz1K5z362GPe8WyW39yfCkhhodZFjUAbnyT8x6tWeIuZYoUnhkyOnaa2ayWe1HDtKm8ZdIrIYRcv+5N8AKBziLfqQZZLgJbhMlml5k8Aee2Xb9M5G7fupraRfi435hL81G0nyT7lEq/Jdip/mNo6u3htu7rjiUrj12epbXBwk3e8UOXn4pu/3EdtjIV2F/2xmb1rZq+Y2UNmNmZmb8//3NY+WAixMlhId9G9AFLOuYcBdONGI8IfOef2zv981GwnhRCLYyFX8AkAP7jp7/sAPGNm+8zsJQvVHBZCtJQ7Brhz7rhzbp+ZfQ1AA8BRAN91zu3Bjav5bV+Wzew5M9tvZvtnZvn3HiFEc1nod/CvAvg2gK8AOAHg9XnTGQC3VbV3zr3gnBt1zo12dfIqJUKI5rKQ7+BrAHwHwNPOuRkAzwN41swSAHYB4O1JhBAtZSEy2Tdw46P4L+a/bv8cwJ8D+BaA
l51zR0KTEwlDB2m5Mpkv0XnvH3zPOz40xLOIVg8NUlu1yiWo69enqA0lv4+pBn+8dZu5BDXSxz/RXDjG64LNzfIaZEOr13jH2wd66Zxkjks/hSJ/XYaHN1Db+EV/Hb2rk/7WSgAwvDbQUirQpmq2zNcfKf/5Vm1waTPbRrIGAWQD20yVySvcj4S/7hoArCbZfJUyb78VWA7KHQPcOfd9AN+/Zfh7d38oIcRyozvZhIgYBbgQEaMAFyJiFOBCRIwCXIiIaXo2WcKAbNqfIVMucXnqnXfe8I67Kpdwutt5Ub1qlWf9lIq8HVKKvAdu3DRC5+x6+F5q27qBS2hT5/0yEwCMX79KbZk2vyy0dcAvnwHAlSs802n3jl3Udt9unlv0X//Lf/aOp+AvgggA1Tn+elYq3OZqXPJCzv9ah1oJbdq8hdounw+kWyR4dmNbBz/ezp3bveOlAn9dRoZvu6fsjugKLkTEKMCFiBgFuBARowAXImIU4EJEjAJciIhpukzWaDRQKJIihIFCiE/9wdP+x6vw7KNkQApr1HkxO5fkUkcy5Zd4ch28+OD4FJfdZqZ4n65rRe6/5XghxI9+d8o7Pvkbnum0ZTOXux66Zxu1VQKZZm0ZvyzkApl8ocy1RJKfnqS1FwCg2CB97ep8fTeu5zJZaXaS2u7t5llo+957n9ounvVLb8U5fn67wnVqY+gKLkTEKMCFiBgFuBARowAXImIU4EJEjAJciIhpfjZZwtDR6ZeaegJF5LpW+bNtymVefDAXeL/KGM9ocm08Cy3b7p/XKPGsn5mZPLUl23mxw6GtvEji1naeTXb8tL83GYzLf2lSCBMALlw6R20Dg7zoJbNVilz6KZd5Qca5QKZZOZB1VS37ZdlUjkubq9euorazlyaobeIcWXsApVn+3E4e/p13fGCA++H6+qmNoSu4EBGjABciYhTgQkSMAlyIiFGACxExy5BsUkJhhiRYNPj7S9o6veMTE3xn8viRM9SWS/Gd8kwP370eJK2S1g720DmpQBLNQM8AtQXyYVAq8kSDoSH/zvy6tXzX9dL4OLUdO/YhtW2qbKY2pnDMzPDXrFDgO9T5aa5GhHbR6xV/sk8yyxNDDh/iba9C7YSGhlZT27r7eW27oVX+eYOreB29XMB/xkKaD6bM7H+Y2a/N7D+ZWc7MXjWzA2b2ovqDC7FyWchH9D8GcMA59wXcaEL4LQBjzrkHAPQBeKKJ/gkhPgULCfC/B/BXZpYC0AvgcwBem7e9CeDLTfJNCPEpuWOAO+dmnXMFAL8GMAFgAMDHX6ryAG77omdmz5nZfjPbPzNDij0IIZrOQr6DD5hZFsCjuPGRfBeAj3eYegDcdg+lc+4F59yoc260q4vfHiiEaC4L+Yj+bwH8qXOuDqCAG73Bn5y3PQ7grSb5JoT4lCxEJvtrAC+a2V8AOAngPwJ4ycwOAjgAwN9j6GMaDg3SgiYReH9JVf2JEt2kDRIAvPfuL6ltfIIna1iaJ17s2fN57/jeR0bpnOlpLgsd/KffUttciSdXHDt3ntpOnTnjHS8W+Ncj57j4kevmCQ/5/Ay1zZD2SnN5LvGFJJhUklt7Ap8M1272S3l9A8N0ztBaLk+tfXA3tfUHarJlQrX+mC2QIAR397et3DHAnXMXcONKfTP+iohCiBWF7mQTImIU4EJEjAJciIhRgAsRMQpwISLGnAsURluKA5hdAXD2pqFBeG6OaQHy45OsBD9Wgg/A/5t+bHTO3aZvNj3Abzug2X7nHBeR5cf/t36sBB9i80Mf0YWIGAW4EBHTigB/oQXH9CE/PslK8GMl+ABE5MeyfwcXQiwf+oguRMQowIWImGUJ8JVSqNHMHjKzMTN7e/5nR4v8SJvZT+d/b8na3OJDy9bFzH5sZu+a2Stm1tmq8+QWP1qyHs0ocLpcV/CvY2UUauwD8CPn3N75n4+W2wEzawPwHv7vGiz72nh8aMm6mNleACnn3MMAugF8Ey04Tzx+DKM1
58mSFzhdrgB/HCujUGMfgGfMbJ+ZvdSKTxLOuaJz7n4AY/NDy742Hh9atS4TAH4w/3sCwF+iNefJrX60aj2WvMDpcgX4HQs1LhMnAHzXObcHN94hH2uRHzezEtamJevinDvunNtnZl8D0ADwPlqwFh4/jqI163HXBU7vxHIF+FXcoVDjMnEGwOs3/T7UIj9uZiWszRm0aF3M7KsAvg3gKwDG0aK1uMWPE2jBeiymwOmdWK4AfwMro1Dj8wCeNbMEbizeoRb5cTMrYW1asi5mtgbAdwA87ZybQYvWwuNHq86TJS9wulwB/hMA6+YLNV7DnQo1No8fAvhzAL8F8LJz7kiL/LiZlbA2rVqXb+DGR+BfmNnbANJozVrc6kcBrVmPvwbwTTP7DYBJ3Chw+qnWQ3eyCRExutFFiIhRgAsRMQpwISJGAS5ExCjAhYgYBbgQEfO/AWLlGfu8tDP5AAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Preview the training image at index 1 straight from the dataset's raw array (no transform applied)\n",
    "plt.imshow(train_data.data[1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 319
    },
    "id": "v89ff86OFYbK",
    "outputId": "30e4c4df-d717-4094-f4d8-7955007ea01c"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(3, 32, 32)\n",
      "(32, 32, 3)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<matplotlib.image.AxesImage at 0x7fb6549f4208>"
      ]
     },
     "execution_count": 10,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD5CAYAAADhukOtAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAf8ElEQVR4nO2dW5BdZ5Xf/+vc+n5vdasltdSSLAkZ+YpQbOwAGQI2hJShZuKCB8IDNZ5KQSVUJg8upiqQqjwwqQDFQ0LKBNeYCcGQAQaXYTJ4jAfDGNvIN1mybFnWXepuXVunL+d+Vh7OcZXsfP+v25L6tJj9/1WpdPpb/e29zt577X36+5+1lrk7hBD/+EmttANCiNagYBciISjYhUgICnYhEoKCXYiEoGAXIiFkrmSymd0N4JsA0gD+p7t/Nfb7Pb19PjQyGrSViwt0XrVcDI67G52TzbVTW66N29LZHLWlUuH9FQtzdE65VKA2r9WozcDfWyqd5vNS4ft3V3cPndMWOR5eq1JbocDPGRCWdOtepzOKBX6sahE/YvIxM1Wr3I96PbY9Pi+T4eGUyfBz5ghfBzFVvE7cKCwUUCqVgxfPZQe7maUB/DcAHwZwAsDvzOwRd3+FzRkaGcWfff2/B20nXn2O7uvM4f3B8VqNuz+6/l3Utn7zdmobWL2e2to7wvs7sO8pOufowT3UVpnlN4l05L31DvRRW6a9Mzi+64730znXbeXHqnjxPLXt2/sCtdXr5eB4uRK+cQPAK/teprb8zFlqK5VL1FYph4Ps/Dl+o5pb4D5Wa3xfq1YNUtvAYDe11Xw2vK8KnYJiIXwn+PsnnqZzruRj/C4AB939kLuXATwM4J4r2J4QYhm5kmBfC+D4JT+faI4JIa5Bln2BzszuM7PdZrZ7Nn9xuXcnhCBcSbCfBDB+yc/rmmNvwd0fcPed7r6zp5f/rSmEWF6uJNh/B2CLmW00sxyATwF45Oq4JYS42lz2ary7V83sCwD+Fg3p7UF33xebU6vVkL8QXt0d6ucrmb4qLNd5ppfOGVu/iftR58ucqTpfpa0vhOWf4oVzdI4X+Mru2uERals/fh21jV+3gdrWrF0XHB8hkicAZLNt1FbtD6/uA8D4utV8XjW8Gl8scnlt5gJXJ86e5apAJiKzwsKr8QND/D23d3EfL+YvUFtbOw+nunPpMJsJ+5K/OEPnlEvh1XhnmhyuUGd3958D+PmVbEMI0Rr0DTohEoKCXYiEoGAXIiEo2IVICAp2IRLCFa3Gv2PcgUpY9iqXuBy2sBCWcSa28m/nzs3PU1ssGWNwOJJkkg3fG7ds2UrnvO+2ndS2djQskwFAX98qaqtkeLZcZ3tYxslEMqisGslsm+dyWImcSwDo7AhLdgP9XG7cvOl6atu//zVqg3E/SqWwlNrXO0DnRBIfcTE/TW2O8HUKxDPpLlwIX6uFBZ50wzLiYhmAerILkRAU7EIkBAW7EAlBwS5EQlCwC5EQWroa7/U6qiQRwqp8hbkt1xEcv3iWlyoaWs1Xute/myeZjIyvobYsW6aN1A+qVPnK/6uTPIFm4dAZvs0UX/V97eWXguPv3c5Xut+/673UFlvdzUfqExw7eio4nstGagPmeGLT8CquvBw7/jrfJinTNVfgak0+z6+rTJbXBuzt5UlDsXp9rLxerE5eW1v4WjTunp7sQiQFBbsQCUHBLkRCULALkRAU7EIkBAW7EAmh5dJbaSEseXR3cEmmdzCcFHLrTTfTOeObtlDbbCTx47VDx6ktvxCWT+ZmeK2wczNcXpuc4vXMeiOJMEjxBIlHf/Cj4Hj2Xn5f/8Dtd1JbNstlxdWruUwJD8tXMxfC3U8A4PkXePecTKROXlcPl+yqtbB0WJ7j5ywdeQTGur7UalwSPXeey3kphCW7WDup/v5wwlY60mZKT3YhEoKCXYiEoGAXIiEo2IVICAp2IRKCgl2I
hHBF0puZHQEwC6AGoOruvOAaAEsZ2tqyQVsl3UPnFTrCjewP53mbnhd/8yy1nT/H66qdPMVrjGXT4ZSibIpnJ5VIGyQAKBa5bWwVPzWnp45SWy/JhpqdydM5Bw4f5n6MDVNbNst9HBsPt4ZaQ8YB4NgUlz1fe5nbRsa4THnkGJG8Kvyc1cvcVovU/2vPcXmwLRO+7gGgUAxvs7eXS4oZ0jLKIs/vq6Gz/zN3IqoKIa4Z9DFeiIRwpcHuAH5hZs+Z2X1XwyEhxPJwpR/j73T3k2Y2AuAxM3vV3Z+89BeaN4H7AKB/gH/VUAixvFzRk93dTzb/Pw3gJwB2BX7nAXff6e47u7rDC21CiOXnsoPdzLrMrOfN1wA+AmDv1XJMCHF1uZKP8aMAfmKNCncZAP/b3f9vbEIqlUFn52jQdnqGZ6IdPB6WXV7Zx+8tqYgsVIu0mirM8kKEaSKxFUpc1pqZ5bbZSGulIyf2U1tXB5cpt23eFjZEJMB/+PXfU9uGjRupbes23vZqaCicldXWzs9LXy+XrlJVXtxyvsSfWayFUmGGZ9/VarxIaHsHl9Dm8nybvZHMvLb2cKZauRxriRbOwKzXuWx42cHu7ocA3HS584UQrUXSmxAJQcEuREJQsAuREBTsQiQEBbsQCaGlBSfT6Qz6B8NZVAePH6DzJo+Es7I6s7zw4sV5XsxxLn+a2iwiXczMhqWymQKXajIkyw8AhkdHqK2jJyxdAcDaCS6CjBMZ5/BLv6Vz0sZluUqNZ3mdOcuLad5ww/bg+HVbNtE545Hste7bbqG2Pa8eo7ZSMVzItJSNZL2By2R15xLx1FS4vx0A5Nq4rNg3wK4DLgMXCuGMz7rz96UnuxAJQcEuREJQsAuREBTsQiQEBbsQCaGlq/Gl0jzeeCNcG+7VNw7Seacm3wiO1yJJKz19XdS2bcsEte3YvoPaJs+EV0CPnuF+rFodTvwBgA2beZJJzxBfqZ++wPfnZ8PKxbGjfMX6TKRF1fbrqQkf3hpecQeA+TmyWswX9+Flrgrse5qrCVu28TZgo2v7g+NPP/tkcBwApqZ58lKlwlfjiwXu/4VI26uO7rCPsZX1edJGLZYIoye7EAlBwS5EQlCwC5EQFOxCJAQFuxAJQcEuREJoqfQ2P5fH008+FnZklNROA7B5+w3B8Y5Im57t12+htm1b11FbrRhOJAEAT4XlpHnwhjiZbDgRAwDS6bDkAgCVKk+cmJ89T2195bA0VK05nXPsNE8aau8+yffVO0BtmzZPBMc98nwpzITrqgHAq8+8SG1e4NfBjrvuDo7fcCNPyCns5tLbGwePUFtnJ6+e3Nc/RG2N7mn/P/k8Py+lUvhYuaQ3IYSCXYiEoGAXIiEo2IVICAp2IRKCgl2IhLCo9GZmDwL4OIDT7r6jOTYI4AcAJgAcAXCvu3OdoEmlXMXp42GZ6pab/gWd19YWrk02yFUyjK3hdcTOR1r/HD/IZa1yPSyHpYyncqUzXAqpOa+hh2qsfVVYAgQAr4X3190Xrv0HAOfmeBZdKsezB+vO5bxGN+/QJD6ju52fs4k149TWnuZ+pBCuG3jDDp5x2N/PJdFHCr+gtqlJHgJrR9ZQW83CNQyzkRZm+XxYHtyfDbdKA5b2ZP8LAG8XK+8H8Li7bwHwePNnIcQ1zKLB3uy3/vbH3T0AHmq+fgjAJ66yX0KIq8zl/s0+6u6TzddTaHR0FUJcw1zx12Xd3c2M/tFkZvcBuA8AslleQ10Isbxc7pN92szGAKD5P+264O4PuPtOd9+ZybT0q/hCiEu43GB/BMBnm68/C+CnV8cdIcRysRTp7fsAPghg2MxOAPgygK8C+KGZfQ7AUQD3LmVnqVQGnd2DQVs2ouLMzIQ/OLQNcolkoco1niLv1oSOgR5qa6sb2SCX3jxyhIsVnuXV3sEnpiLtmuqp8LzuIS795JzLjekOntnmOa591i383qzGpbxUmr/nbFeO2jq6ua1aCsus505O0zlD
XbwN1T0fu4vadr90hNrmIsUoi6UzwfESafEEAP094Ws/k+bnZNFgd/dPE9OHFpsrhLh20DfohEgICnYhEoKCXYiEoGAXIiEo2IVICC39lksu14ax9eFsI0vx+06xGM7wmc5z93P9PMurUuVSjUW+5VeYC2dQVZz7nsnwwpHVNLd19vIMsJGhGWrz82G5phzpUWZ17n9HRwe1pSJZh3UP769W4zJlKhsp9pnmPs7N8yxGIwUY2yLXW/4Ml+U6OsPSMQC8//Ybqe21N45S295XpoLjc3mejZgjhUzr9VgGoBAiESjYhUgICnYhEoKCXYiEoGAXIiEo2IVICC2V3twAt7C8UolIQwuzYWmlLSILzeYjhSOLvNDjQp7LOFmS9NbTxSW0VQNcqukd5Blgq/r5e6tl+qit0BY+juc38Ky3Um2S2hDJzKtVI9l3JEOwluLZiBaR3voHefZdvRbxkVxXfX38+OZ4LRbMzEZkz0pYmgWAm7evprb+nvD18+ijvLjlmelw4dZqJI70ZBciISjYhUgICnYhEoKCXYiEoGAXIiG0ttyrO0BWcDN1vrLbF/7OP8b7yPI4gHdt4vXputv5Smza+P1vPh9eiS0uXKRzOroq1LZtC1+pH9+wjtpS2Q3UNjcT9nF8bIz7cZgWB0bvIDn4AAYHeLJOJhNONorkacAjiTXtXZ3UVi1GVqDJ/rKxxCtwtWZouJva5ha4KjA/E052AYC1q8I17z7xLz9C5/z1z/4uOJ7J8IOoJ7sQCUHBLkRCULALkRAU7EIkBAW7EAlBwS5EQlhK+6cHAXwcwGl339Ec+wqAPwbwZt+aL7n7zxfbVk9XJz5w+3uCtk3X30TnnTp5Mji+dg2XrrZu2Uxtq1eNUFvauZw3S5IgSpFkEUvx7XV38USY7m4ueaVzXDrMEgmzMB9uMQQAt+7gUt7E1glqq9S5rOjkOVKtc5nM0/xYpbP8Uq0UuZ5XJ4khqQx/zlk79wOReaUKPx6ZNK9tWCuHr6tVEZnvzn/63uD4b599mc5ZypP9LwDcHRj/hrvf3Py3aKALIVaWRYPd3Z8EwPNFhRC/F1zJ3+xfMLM9ZvagmfFkYyHENcHlBvu3AGwGcDOASQBfY79oZveZ2W4z2z03z5P7hRDLy2UFu7tPu3vN3esAvg1gV+R3H3D3ne6+s7uLLzgIIZaXywp2M7s0q+KTAPZeHXeEEMvFUqS37wP4IIBhMzsB4MsAPmhmNwNwAEcA/MlSdtbZ2YH33PiuoO3dt3DprbAjLKN19fGsK17pDHDj0koqIpEMdoXriEW6P0XvpnXSmgiI1xJDROIplcLtnzZft57O6chxCbAwzzP6PBW5fCxs80h9t7pzWy1yzmItj8qF8PGo1fl7TmUi10fkjM6e4xLs0cPHqe2OO28Jji9UeD3ETiIPRpTexYPd3T8dGP7OYvOEENcW+gadEAlBwS5EQlCwC5EQFOxCJAQFuxAJoaUFJ1OpFDpIpld3O2+h1NVJ3IwU14sVNrSY9BaTeDwsldUrXEKLyUkWKXpYjYiHMXnFScHM7n6eIVit8X3V6pEqkKTFEwA4asHxVMz5GrfVMlwSdURONilwavWwfwDQFnnP2Ro/Z11FPs+nwxIgAJw5NB0cX7eNFx09mwp/GzV2ePVkFyIhKNiFSAgKdiESgoJdiISgYBciISjYhUgILZXe0uk0evrCEpBHss0WSmH5xEu8J1eJzAGA+bl5aitX+LxSKZxtVq1y6aoSyVCrRPa1EOkbtjDPs6GqJJOuZ7CPzunp433x+nuGqa09F+7nBgA11rvPIn3ZwG09PbwA57nT/DgWC2GJql7nxZUM/H3Va/ya6+3h8vGG9aPUVlgIX48eKc7Z1xOWsNMROVdPdiESgoJdiISgYBciISjYhUgICnYhEkJLV+NnZvL460f+JmirZX9N5124EE4UmLt4ls5JRXIjYiv109PhfQFAjWTXDEbaSQ0MD1FbW5of/vnz4ZZAAHDg
9f3Ulp8Lrz6Pb+QtntJZroT09nD/N27kde3WjYfr9W3ctJbOGWzjWRw97dzHeqQWIdLh5JRKja90pyMtntIRH0cnIspFL1+pr3g4KSfNRQEMDobfcyaSHKYnuxAJQcEuREJQsAuREBTsQiQEBbsQCUHBLkRCWEr7p3EA3wUwika7pwfc/ZtmNgjgBwAm0GgBda+7X4htKz87h8eeeCpo61+3jc7zWlhOeuGpJ+icDet4/a7hIS4nnTwxRW1VUresc5AnkpRTPElm+gRvCfShXbdT2803vpvaFkrF4Hgqy0/14WNHqe3A629Q28t7X6C2/r5wE88//KNP0jl3vHsrteUiPbbWjY1TW5lIbxYp1harG1ghtfUAIJWJ1LXr54k8HSR5pZ7mEjETIiMlFJf0ZK8C+FN3vx7AbQA+b2bXA7gfwOPuvgXA482fhRDXKIsGu7tPuvvzzdezAPYDWAvgHgAPNX/tIQCfWC4nhRBXzjv6m93MJgDcAuAZAKPuPtk0TaHxMV8IcY2y5GA3s24APwLwRXfPX2pzdwfCxbvN7D4z221mu8tlnvgvhFhelhTsZpZFI9C/5+4/bg5Pm9lY0z4G4HRorrs/4O473X1nLse/HyyEWF4WDXZrtE/5DoD97v71S0yPAPhs8/VnAfz06rsnhLhaLCXr7Q4AnwHwspm92Bz7EoCvAvihmX0OwFEA9y62oYHBIfyrT//roK1tZAudtzAblsNef/klOmdsNZdjUpE6XR3tPIOqXA+38Nm6g/s+MMYz4haGeR20j3/0n1NbZ08Htc0T6S3SqQlV0tYKAIrV8PYA4PTp89R29PCp4HhnJz++UyfOUduRfa9TW6rIfTw0FfzAiV0f2UnnbJhYQ22xbLlUeyRNLctlOWO15ozPyVn4nMWkt0WD3d1/A4Bt4kOLzRdCXBvoG3RCJAQFuxAJQcEuREJQsAuREBTsQiSElhacNAPacuH7y4FX99J5+Yth6c1j2UllnjE0F2n/ZBHtor0tnGtUWeDtmC6e4T5OH+NZb3/zt+HCnABwYTayv7mLwfGeXi559Q2EW3IBQFekUOKJE2F5DQBGhsOFJdt7uRT565/x93z+9T3UVivzFlsHp8IFRE9EWmht2c6l1L7eTm4b4C22Ojp51ltfV/i6yrbz4pGdneHz4s6vXz3ZhUgICnYhEoKCXYiEoGAXIiEo2IVICAp2IRJCS6W3erWC2XNhGe2XP/0ZnXd86kRwPFUJZ6EBwJ49eWqLpQZVqzyrCSTT6LFHf0mn5LJcurr5lluprZzrobZ8aYHaDh0LZ3mdO8f7w5WLPOvt1NQRajt8hG9z5y3vCY7/28//ezrn2ad/S23VizwjLl/iRVEK4ZoqOLSby56/fm6S2royXObL5rhUlm7j10EPkd7WbZigc+75w08Fx8tV/vzWk12IhKBgFyIhKNiFSAgKdiESgoJdiITQ0tX4bDaHsdGxoG3LxEY6zxFeLc5EWiulIyvuqTS/x3mdJ67k2rvChixPclizJpwQAgAfvOsuauvpjCRctPPada/sDdflO3CQt3FavXaC2oqRtkvpDu7j3gOvBsdfOXCAzumc2E5tp07x9zzQz20juXBduM5uXsfv/BRvh3Xu5EFqO3M2nHQDAMVaJGmLFAicnOHh+b4PhedUedk6PdmFSAoKdiESgoJdiISgYBciISjYhUgICnYhEsKi0puZjQP4LhotmR3AA+7+TTP7CoA/BnCm+atfcvefx7ZVrVZx/ky4ZdBt/+R9dN77PvCB4HhbG088yETktVj7p3qkFVIa4f1VylzvKJR50sq5E4ep7XyRJ1ycP8vbLh0iEtup0+EEJADoHuHtjtDGZUXLcemtXA0npzz2q9/QORs230Bt44NcwmxP8cu4kyQilYq8Bt2h/D5q6+7htfxqzpOopi7MUdvw8ERwfKHCr8Vf/urZ4PjsLK+vuBSdvQrgT939eTPrAfCcmT3WtH3D3f/rErYhhFhhltLr
bRLAZPP1rJntB8Bvs0KIa5J39De7mU0AuAXAM82hL5jZHjN70Mz415iEECvOkoPdzLoB/AjAF909D+BbADYDuBmNJ//XyLz7zGy3me2eneN/JwkhlpclBbuZZdEI9O+5+48BwN2n3b3m7nUA3wawKzTX3R9w953uvrOnm1dfEUIsL4sGuzVapHwHwH53//ol45dmtHwSAG/pIoRYcZayGn8HgM8AeNnMXmyOfQnAp83sZjTkuCMA/mSxDaVShi7StuZcvkjnvbDnueD4yAhfJhgdGaa2SoXLWhcuzFAbimEfM3W+vbUbuaw1PsA/6Zw8wOugzc/xmmsjo6uD451D/XROup3LSQsFfl7GxtZT29SpcN3As+fC7akAYGxNpC1XpNXXXIkff2TC11ulzuXStg6S3QigLZJNWT53htqQCteZA4BRknVYLvEWZuxw8KO0tNX43wAIvcOopi6EuLbQN+iESAgKdiESgoJdiISgYBciISjYhUgILS04mTKgLRvO5CkVueT11FOPB8e9wmWh3k5eULBS4dlJxQJvKZUh98YNE+N0zo7brqe2zeu5LDdzPCxdAcDUhbPUlusIS02bh8KSHACcOcMzsm7YtoPa3n3DNmp7+H99NzieQbgAJABU5vn5LJe5zWNVFtvD5zrWjmli4yZqO338Nb6vFM/C7Oji+9u+fWtwvLjAz8v42Ehw/Fc5LvHpyS5EQlCwC5EQFOxCJAQFuxAJQcEuREJQsAuREFoqvdXrdSwUSAHGSBHIuz768fD2yjxLKh2R1+o1XsjP01w+SWfCslF7Fy+8ODXDpbzZGd737HyB+2/tvAjkay8eCo6f+y3PyNq0kUto771uC7WVIxlxHbmw1OSRjMNYhl0qzS9V0ioNAFCokz6BNX58N6zj0ltx7hy1Xd/Ls+Wefe4Fajt1NCznFeb59e0LF4Lj5RLPiNSTXYiEoGAXIiEo2IVICAp2IRKCgl2IhKBgFyIhtDbrLWXo6g7LV32RSnk9q8JZQaWIzNAeuY/ljGdeeQfPlmvrDM+rF3l20uxsntrSnbzQ48hmXiBycyfPenv9cLjXG4xLillSBBQATk4eo7ahYV7wk9nKBS4nlUq8GOV8JCOuFMkOq5TCUm+mnculo2tWUdvRyWlqmz5Gjj2A4hx/b2/sezE4PjTE/fCBwfB4pDCnnuxCJAQFuxAJQcEuREJQsAuREBTsQiSERVfjzawdwJMA2pq//1fu/mUz2wjgYQBDAJ4D8Bl35/1qANTrRSzMkuSPOr/vZK07OD49zVc4X3/lCLW1Z/iKe66Pr4IPk3ZTa4b76JxMJMFnqG+I2iK5OigWwkkQADAyEl7hX7smvHoLAJNTU9R24MB+apsob6Q2ppTMzvJztrDAV7rzF7mqEVuNr5XDiUjpNp60sm8vbx0Wa8k0MjJKbWtv5LX8RlaF5w2v4nUD24n/j//DE3TOUp7sJQB/4O43odGe+W4zuw3AnwP4hrtfB+ACgM8tYVtCiBVi0WD3Bm/eOrPNfw7gDwD8VXP8IQCfWBYPhRBXhaX2Z083O7ieBvAYgDcAzLj7m0nBJwCsXR4XhRBXgyUFu7vX3P1mAOsA7ALwrqXuwMzuM7PdZrZ7dpYUrhBCLDvvaDXe3WcAPAHgdgD9ZvbmAt86ACfJnAfcfae77+zp4V9RFEIsL4sGu5mtMrP+5usOAB8GsB+NoP+j5q99FsBPl8tJIcSVs5REmDEAD5lZGo2bww/d/VEzewXAw2b2nwG8AOA7i26p7qiTNj6pyH0nUwkncfSSVlIA8NzTv6K2qWmeSGJZnhSya9d7guN33r6Tzrl4kUtNe55/htrmizzx48Cx49R26MiR4Hhhgf8J5c6LuLX38mSMfH6W2mZJi6r5PJcNI6XkkElza1/kE+OajWF5cGBojM4ZWcMlrzW33EBtg5EadLlYbUNmiyQvwcPxkoq0oFo02N19D4BbAuOH0Pj7XQjxe4C+QSdEQlCwC5EQFOxC
JAQFuxAJQcEuREKwWM2qq74zszMAjjZ/HAbANbDWIT/eivx4K79vfmxw96Be2tJgf8uOzXa7Oxeo5Yf8kB9X1Q99jBciISjYhUgIKxnsD6zgvi9FfrwV+fFW/tH4sWJ/swshWos+xguREFYk2M3sbjN7zcwOmtn9K+FD048jZvaymb1oZrtbuN8Hzey0me29ZGzQzB4zs9eb//PeSsvrx1fM7GTzmLxoZh9rgR/jZvaEmb1iZvvM7N81x1t6TCJ+tPSYmFm7mT1rZi81/fhPzfGNZvZMM25+YBbpYxbC3Vv6D0AajbJWmwDkALwE4PpW+9H05QiA4RXY7/sB3Apg7yVj/wXA/c3X9wP48xXy4ysA/kOLj8cYgFubr3sAHABwfauPScSPlh4TNLJ9u5uvswCeAXAbgB8C+FRz/H8A+DfvZLsr8WTfBeCgux/yRunphwHcswJ+rBju/iSA828bvgeNwp1Aiwp4Ej9ajrtPuvvzzdezaBRHWYsWH5OIHy3FG1z1Iq8rEexrAVxafWEli1U6gF+Y2XNmdt8K+fAmo+4+2Xw9BYAXIV9+vmBme5of85f9z4lLMbMJNOonPIMVPCZv8wNo8TFZjiKvSV+gu9PdbwXwUQCfN7P3r7RDQOPOjsaNaCX4FoDNaPQImATwtVbt2My6AfwIwBfd/S1dIVp5TAJ+tPyY+BUUeWWsRLCfBDB+yc+0WOVy4+4nm/+fBvATrGzlnWkzGwOA5v+nV8IJd59uXmh1AN9Gi46JmWXRCLDvufuPm8MtPyYhP1bqmDT3/Y6LvDJWIth/B2BLc2UxB+BTAB5ptRNm1mVmPW++BvARAHvjs5aVR9Ao3AmsYAHPN4OrySfRgmNiZoZGDcP97v71S0wtPSbMj1Yfk2Ur8tqqFca3rTZ+DI2VzjcA/NkK+bAJDSXgJQD7WukHgO+j8XGwgsbfXp9Do2fe4wBeB/B3AAZXyI+/BPAygD1oBNtYC/y4E42P6HsAvNj897FWH5OIHy09JgBuRKOI6x40biz/8ZJr9lkABwH8HwBt72S7+gadEAkh6Qt0QiQGBbsQCUHBLkRCULALkRAU7EIkBAW7EAlBwS5EQlCwC5EQ/h+CqIklWmKmUgAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light",
      "tags": []
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Display an image after the ToTensor transform\n",
    "# train_data[1][0] is the image tensor of sample 1, converted to a NumPy array\n",
    "temp = train_data[1][0].numpy()\n",
    "# Channels come first here: prints (3, 32, 32)\n",
    "print(temp.shape)\n",
    "temp = temp.transpose(1, 2, 0) # reorder axes 0, 1, 2 -> 1, 2, 0 so the channel axis is last\n",
    "# Now prints (32, 32, 3), the layout plt.imshow accepts for RGB images\n",
    "print(temp.shape)\n",
    "plt.imshow(temp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "id": "WRt0q2P6FYbO"
   },
   "outputs": [],
   "source": [
    "# Hyperparameters\n",
    "EPOCH = 10\n",
    "BATCH_SIZE = 128\n",
    "LR = 0.001\n",
    "\n",
    "# Batch the datasets with DataLoader\n",
    "train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)\n",
    "# Evaluation must iterate the held-out split (test_data), not train_data, otherwise the\n",
    "# reported accuracy is measured on the training set. No shuffle: this loader is not used for training.\n",
    "test_loader = DataLoader(dataset=test_data, batch_size=BATCH_SIZE)\n",
    "\n",
    "# Use ResNet-18 with pretrained weights (note: the module is `models`, with an s)\n",
    "# NOTE(review): the pretrained head keeps 1000 output features while CIFAR-10 has 10 classes;\n",
    "# consider replacing model.fc with a 10-way layer - confirm against the evaluation code.\n",
    "model = torchvision.models.resnet18(pretrained=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "id": "vmokHSe5HNd3",
    "outputId": "f550927d-a9e7-464a-d185-ff68b8044954"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 12,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Check whether a CUDA GPU is available to this runtime\n",
    "torch.cuda.is_available()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 1000
    },
    "id": "xBBPdXmBFYbR",
    "outputId": "453701ad-4d7a-428f-9ab3-1262b12e570f"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "ResNet(\n",
       "  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n",
       "  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  (relu): ReLU(inplace=True)\n",
       "  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
       "  (layer1): Sequential(\n",
       "    (0): BasicBlock(\n",
       "      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu): ReLU(inplace=True)\n",
       "      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    )\n",
       "    (1): BasicBlock(\n",
       "      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu): ReLU(inplace=True)\n",
       "      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    )\n",
       "  )\n",
       "  (layer2): Sequential(\n",
       "    (0): BasicBlock(\n",
       "      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu): ReLU(inplace=True)\n",
       "      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (downsample): Sequential(\n",
       "        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
       "        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      )\n",
       "    )\n",
       "    (1): BasicBlock(\n",
       "      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu): ReLU(inplace=True)\n",
       "      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    )\n",
       "  )\n",
       "  (layer3): Sequential(\n",
       "    (0): BasicBlock(\n",
       "      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu): ReLU(inplace=True)\n",
       "      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (downsample): Sequential(\n",
       "        (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
       "        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      )\n",
       "    )\n",
       "    (1): BasicBlock(\n",
       "      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu): ReLU(inplace=True)\n",
       "      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    )\n",
       "  )\n",
       "  (layer4): Sequential(\n",
       "    (0): BasicBlock(\n",
       "      (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu): ReLU(inplace=True)\n",
       "      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (downsample): Sequential(\n",
       "        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
       "        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      )\n",
       "    )\n",
       "    (1): BasicBlock(\n",
       "      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu): ReLU(inplace=True)\n",
       "      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    )\n",
       "  )\n",
       "  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))\n",
       "  (fc): Linear(in_features=512, out_features=1000, bias=True)\n",
       ")"
      ]
     },
     "execution_count": 19,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Choose the compute device first: the first CUDA GPU if available, otherwise the CPU.\n",
    "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "# Move the model BEFORE building the optimizer, as the torch.optim docs advise,\n",
    "# so the optimizer is constructed over the parameters that live on `device`.\n",
    "model = model.to(device)\n",
    "\n",
    "# Cross-entropy loss for multi-class classification (CIFAR-10 has 10 classes, not 2).\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "\n",
    "# Adam optimizer: adapts each parameter's step size using first- and second-moment\n",
    "# estimates of the gradient. LR is the learning rate defined in the config cell.\n",
    "optimizer = optim.Adam(model.parameters(), lr=LR)\n",
    "\n",
    "# Last expression of the cell: display the model architecture, as the original cell did.\n",
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 194
    },
    "id": "HHKbPLBPFYbX",
    "outputId": "694919e1-323d-4abf-e5e4-6edbd77ff44c"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch1 loss: 0.8631 time 20.5853\n",
      "epoch2 loss: 0.5325 time 20.7143\n",
      "epoch3 loss: 0.5892 time 20.9142\n",
      "epoch4 loss: 0.3453 time 21.0151\n",
      "epoch5 loss: 0.2485 time 21.0880\n",
      "epoch6 loss: 0.3284 time 21.0842\n",
      "epoch7 loss: 0.3116 time 21.0700\n",
      "epoch8 loss: 0.2450 time 21.0525\n",
      "epoch9 loss: 0.3559 time 21.0755\n",
      "epoch10 loss: 0.2687 time 21.0728\n"
     ]
    }
   ],
   "source": [
    "# 训练，google GPU，每个epoch比cpu快了3.5倍\n",
    "# Make sure BatchNorm/Dropout are in training mode, even if the evaluation cell\n",
    "# (which calls model.eval()) was executed before this cell is re-run.\n",
    "model.train()\n",
    "for epoch in range(EPOCH):\n",
    "    start_time = time.time()\n",
    "    # Accumulate the loss on-device so we do not force a GPU sync on every step.\n",
    "    loss_sum = torch.zeros((), device=device)\n",
    "    sample_count = 0\n",
    "    for inputs, labels in train_loader:\n",
    "        # Move the batch to the same device as the model.\n",
    "        inputs, labels = inputs.to(device), labels.to(device)\n",
    "        # Clear the gradients left over from the previous step.\n",
    "        optimizer.zero_grad()\n",
    "        # Forward pass and loss computation.\n",
    "        outputs = model(inputs)\n",
    "        loss = criterion(outputs, labels)\n",
    "        # Backward pass, then parameter update along the gradient direction.\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        # criterion returns the batch mean; weight it by the batch size so a\n",
    "        # smaller final batch does not skew the epoch average.\n",
    "        loss_sum += loss.detach() * labels.size(0)\n",
    "        sample_count += labels.size(0)\n",
    "    # Report the mean loss over the whole epoch rather than the (noisy) loss of the\n",
    "    # last mini-batch; max(..., 1) keeps an empty loader from raising.\n",
    "    epoch_loss = loss_sum.item() / max(sample_count, 1)\n",
    "    print('epoch{} loss:{: .4f} time{: .4f}'.format(epoch+1, epoch_loss, time.time()-start_time))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "id": "UQDPHN4wFYba",
    "outputId": "63fed830-68cd-46ac-d1cd-6eb5c89c52f2"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cifar10_resnet_colab_epoch10.pt saved \n"
     ]
    }
   ],
   "source": [
    "# 保存训练模型\n",
    "# Checkpoint path, relative to the current working directory.\n",
    "file_name = 'cifar10_resnet_colab_epoch10.pt'\n",
    "# Serialize the whole model object (not just its state_dict) to that path.\n",
    "torch.save(obj=model, f=file_name)\n",
    "print(f'{file_name} saved ')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "id": "EDMNr8m-FYbd",
    "outputId": "399b6e44-3df5-4ac1-da6e-85fdde85b621"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "10000张测试图像的准确率： 96.6120\n"
     ]
    }
   ],
   "source": [
    "# 测试\n",
    "# NOTE(review): loading a whole-model file unpickles arbitrary objects; only load trusted checkpoints.\n",
    "# map_location lets a checkpoint that was saved on a GPU be reloaded on a CPU-only runtime.\n",
    "model = torch.load(file_name, map_location=device)\n",
    "# eval() puts BatchNorm on its running statistics and disables Dropout, so predictions\n",
    "# do not depend on the composition or size of each evaluation batch.\n",
    "model.eval()\n",
    "\n",
    "correct, total = 0, 0\n",
    "# Inference needs no autograd graph; no_grad saves memory and time.\n",
    "with torch.no_grad():\n",
    "    for images, labels in test_loader:\n",
    "        images, labels = images.to(device), labels.to(device)\n",
    "        # Forward pass.\n",
    "        out = model(images)\n",
    "        # The predicted class is the index of the largest logit in each row.\n",
    "        _, predicted = torch.max(out, 1)\n",
    "        # Count how many predictions match the ground-truth labels.\n",
    "        total += labels.size(0)\n",
    "        correct += (predicted == labels).sum().item()\n",
    "\n",
    "# Report accuracy over the number of images actually evaluated (the message used to\n",
    "# hard-code 10000); an empty loader yields NaN instead of a ZeroDivisionError.\n",
    "accuracy = 100 * correct / total if total else float('nan')\n",
    "print('{}张测试图像的准确率：{: .4f}'.format(total, accuracy))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 103,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.return_types.max(\n",
       "values=tensor([10.5876, 13.8441,  9.7160, 13.9375, 21.3442, 12.6578, 17.7501, 12.2050,\n",
       "        16.3147, 13.3066, 14.7981, 12.0003, 11.9597, 17.4236, 12.1771, 15.5576]),\n",
       "indices=tensor([7, 5, 8, 0, 8, 4, 7, 0, 3, 3, 3, 0, 3, 5, 0, 7]))"
      ]
     },
     "execution_count": 103,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Taking the max along dim 1 returns (values, indices): the largest entry of each row\n",
    "# and the position where it was found.\n",
    "# Softmax is monotonic, so that position is also the highest-probability class,\n",
    "# i.e. the predicted label.\n",
    "out.data.max(dim=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "dtp-CFSsI-xN"
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "name": "Action1_.ipynb",
   "provenance": [],
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
